﻿using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.IO;
using MathNet.Numerics;
using MachineLearning;

namespace HLDA
{
    public class HldaTopicFixedTree : HldaTopic
    {
        /// <summary>
        /// Creates a root topic. The parameterless base constructor runs implicitly and
        /// initialises the node with level 0, no parent and an empty child list.
        /// </summary>
        public HldaTopicFixedTree()
        {
        }

        /// <summary>
        /// Creates a topic one level below <paramref name="parent"/>. All the work is done by
        /// the base constructor, which also appends the new topic to the parent's child list.
        /// </summary>
        /// <param name="parent">Topic that becomes this node's parent.</param>
        public HldaTopicFixedTree(HldaTopicFixedTree parent)
            : base(parent)
        {
        }

        /// <summary>
        /// Picks one of this topic's existing children uniformly at random. Unlike the base
        /// implementation, it never creates a new child.
        /// </summary>
        /// <returns>A randomly chosen element of <c>children</c>.</returns>
        /// <exception cref="ArgumentOutOfRangeException">
        /// Thrown by the list indexer when this topic has no children.
        /// </exception>
        public override HldaTopic SampleChild()
        {
            // Draw over however many children exist instead of a hard-coded 2, so the
            // choice cannot drift out of step with the tree that was actually built.
            // For the two-child nodes produced by Split this is identical to before.
            // NOTE(review): assumes Global.random follows System.Random.Next(min, max)
            // semantics (exclusive upper bound; returns min when min == max) - confirm.
            int pick = Global.random.Next(0, children.Count);
            return children[pick];
        }

        /// <summary>
        /// Recomputes <c>ncrp</c> and <c>weights</c> for this topic, then does the same for
        /// every child, so that each node ends up holding the running totals accumulated
        /// along the path from the root down to it.
        /// </summary>
        /// <param name="f">
        /// Counts read as <c>f[level, vocabularyIndex]</c> for every vocabulary index below
        /// <c>Global.vocabSize</c>. Presumably the per-level word counts of the document
        /// being placed - confirm with the caller.
        /// </param>
        public override void CalculatePathToLeaves(int[,] f)
        {
            int vocabSize = Global.vocabSize;
            double eta = Global.eta;
            double etaMass = vocabSize * eta;
            double topicTotal = this.wordCount.Sum();
            int depth = this.level;

            double lnGammaCurrent = 0;
            double lnGammaWithDoc = 0;
            double docTotal = 0;
            for (int word = 0; word < vocabSize; word++)
            {
                int extra = f[depth, word];
                if (extra > 0)
                {
                    // Only words that actually occur in f contribute to the two sums.
                    lnGammaCurrent += SpecialFunctions.GammaLn(eta + wordCount[word]);
                    lnGammaWithDoc += SpecialFunctions.GammaLn(eta + wordCount[word] + extra);
                }
                docTotal += extra;
            }

            // Log-ratio of Gamma terms for this level (looks like the Dirichlet-multinomial
            // log-likelihood of f's counts under this topic - confirm).
            double levelWeight = SpecialFunctions.GammaLn(etaMass + topicTotal) - lnGammaCurrent + lnGammaWithDoc - SpecialFunctions.GammaLn(etaMass + topicTotal + docTotal);

            if (!IsRoot())
            {
                // Fixed-tree variant of the prior: every child is credited with half of
                // gamma on top of its own customer count.
                double smoothedShare = (customers + (Global.gamma / 2)) / (parent.customers + Global.gamma);
                ncrp = parent.ncrp + Math.Log(smoothedShare);
                weights = parent.weights + levelWeight;
            }
            else
            {
                ncrp = 0;
                weights = levelWeight;
            }

            children.ForEach(child => child.CalculatePathToLeaves(f));
        }

        /// <summary>
        /// Grows a complete binary subtree beneath this topic.
        /// </summary>
        /// <param name="depth">
        /// Number of levels the subtree should have, counting this topic itself. Values of
        /// 1 or less add nothing.
        /// </param>
        public void Split(int depth)
        {
            if (depth > 1)
            {
                // Two children per node. Constructing a child attaches it to this topic,
                // and each child is fully expanded before its sibling is created.
                for (int i = 0; i < 2; i++)
                {
                    new HldaTopicFixedTree(this).Split(depth - 1);
                }
            }
        }
    }

    public class HldaTopic 
    {
        /// <summary>
        /// Dotted path label. Empty for a root; the child constructor sets it to the
        /// parent's name followed by "." and the child's index among its siblings.
        /// </summary>
        public string name { get; set; }
        /// <summary>Depth in the tree: 0 for a root, the parent's level plus one otherwise.</summary>
        public int level { get; set; }
        /// <summary>Parent topic; null for a root and after <see cref="Remove"/> has run.</summary>
        public HldaTopic parent { get; set; }
        /// <summary>
        /// Count used in the child-selection probabilities and in the <c>ncrp</c> terms.
        /// Within this file it is only ever initialised to 0; presumably the sampler
        /// maintains it - confirm.
        /// </summary>
        public int customers { get; set; }
        /// <summary>Child topics, in the order they were attached by the child constructor.</summary>
        public List<HldaTopic> children { get; set; }
        /// <summary>Per-word counts for this topic, one slot per vocabulary index (length <c>Global.vocabSize</c>).</summary>
        public int[] wordCount { get; set; }

        /// <summary>
        /// Running sum of log child-selection probabilities from the root down to this
        /// topic (0 at the root), written by <c>CalculatePathToLeaves</c> and adjusted by
        /// <c>CalculatePathToInternalNodes</c>.
        /// </summary>
        public double ncrp { get; set; }
        /// <summary>
        /// Temporary storage for NCRP: the running sum of per-level log weights from the
        /// root down to this topic. Overwritten on every <c>CalculatePathToLeaves</c> call.
        /// </summary>
        public double weights { get; set; }

        /// <summary>
        /// Creates a detached root topic: level 0, no parent, no children, no customers,
        /// an empty name and an all-zero word-count vector of length <c>Global.vocabSize</c>.
        /// </summary>
        public HldaTopic()
        {
            level = 0;
            customers = 0;
            children = new List<HldaTopic>();
            parent = null; //root node
            // A freshly allocated int[] is already zero-filled by the runtime, so no
            // explicit clearing pass over the vocabulary is needed.
            wordCount = new int[Global.vocabSize];
            name = "";
        }

        /// <summary>
        /// Reports whether this topic sits at the top of the tree.
        /// </summary>
        /// <returns>
        /// True when <c>level</c> is 0. Only the level is checked; <c>parent</c> is not
        /// consulted.
        /// </returns>
        public bool IsRoot()
        {
            bool atTopLevel = (0 == level);
            return atTopLevel;
        }

        /// <summary>
        /// Creates a topic one level below <paramref name="parent"/>, names it after its
        /// position among the parent's children and appends it to the parent's child list.
        /// </summary>
        /// <param name="parent">Topic that becomes this node's parent; must not be null.</param>
        /// <exception cref="ArgumentNullException"><paramref name="parent"/> is null.</exception>
        public HldaTopic(HldaTopic parent) : this()
        {
            // Fail with a descriptive exception instead of a NullReferenceException
            // part-way through initialisation.
            if (parent == null)
            {
                throw new ArgumentNullException("parent");
            }

            level = parent.level + 1;
            this.parent = parent;
            // The sibling index is taken before this topic is appended, so the first
            // child gets ".0", the second ".1", and so on.
            name = string.Format("{0}.{1}", parent.name, parent.children.Count);
            parent.children.Add(this);
        }

        /// <summary>
        /// Label made of the level number immediately followed by <c>name</c>.
        /// </summary>
        public string DisplayName
        {
            get
            {
                string label = string.Concat(level, name);
                return label;
            }
        }

        /// <summary>
        /// Writes one console line describing this topic (level, path suffix, child count,
        /// customer count) and then does the same for every child.
        /// </summary>
        /// <param name="sub">
        /// Path suffix printed right after the level; each child receives
        /// <c>sub + "." + index</c>.
        /// </param>
        public void Print(string sub)
        {
            Console.WriteLine("Level:{0}{1}\tChildren:{2}\tCustomers:{3}", level, sub, children.Count, customers);

            int position = 0;
            foreach (HldaTopic child in children)
            {
                child.Print(sub + "." + position);
                position++;
            }
        }

        /// <summary>
        /// Writes this topic's header line and the first 50 entries of its sorted word list
        /// to the console, then recurses into every child.
        /// </summary>
        /// <param name="sub">
        /// Path suffix printed right after the level; each child receives
        /// <c>sub + "." + index</c>.
        /// </param>
        /// <param name="vocabularyIndex">
        /// Word string for each vocabulary index; needs at least <c>Global.vocabSize</c>
        /// entries.
        /// </param>
        public void PrintTopWords(string sub, List<string> vocabularyIndex)
        {
            const int MaxWords = 50;

            double eta = Global.eta;
            double Veta = Global.eta * Global.vocabSize;
            double sum = wordCount.Sum();
            // The normaliser is the same for every word, so take its log once.
            double logNormaliser = Math.Log(sum + Veta);

            Console.WriteLine("{0}{1}\t{2}\t{3}", level, sub, children.Count, customers);

            Result[] results = new Result[Global.vocabSize];
            for (int v = 0; v < Global.vocabSize; v++)
            {
                // Smoothed log-probability of word v under this topic.
                double p = Math.Log(wordCount[v] + eta) - logNormaliser;
                results[v] = new Result(p, vocabularyIndex[v]);
            }

            // Ordering is whatever Result's own comparison defines - presumably most
            // probable first; confirm against the Result type.
            Array.Sort(results);

            // Never read past the end: a vocabulary smaller than MaxWords used to throw
            // IndexOutOfRangeException here.
            int shown = Math.Min(MaxWords, results.Length);
            for (int i = 0; i < shown; i++)
            {
                Console.WriteLine("\t\t\t{0:0.000}\t{1}", results[i].Prob, results[i].Word);
            }

            for (int i = 0; i < children.Count; i++)
            {
                children[i].PrintTopWords(sub + "." + i, vocabularyIndex);
            }
        }

        /// <summary>
        /// Writes this topic's header line and the first 25 entries of its sorted word list
        /// to <paramref name="sw"/>, then recurses into every child.
        /// </summary>
        /// <param name="sub">
        /// Path suffix written right after the level; each child receives
        /// <c>sub + "." + index</c>.
        /// </param>
        /// <param name="vocabularyIndex">
        /// Word string for each vocabulary index; needs at least <c>Global.vocabSize</c>
        /// entries.
        /// </param>
        /// <param name="sw">Destination writer. It is neither flushed nor closed here.</param>
        public void WriteTopWords(string sub, List<string> vocabularyIndex, StreamWriter sw)
        {
            const int MaxWords = 25;

            double eta = Global.eta;
            double Veta = Global.eta * Global.vocabSize;
            double sum = wordCount.Sum();
            // The normaliser is the same for every word, so take its log once.
            double logNormaliser = Math.Log(sum + Veta);

            sw.WriteLine("{0}{1}\t{2}\t{3}", level, sub, children.Count, customers);

            Result[] results = new Result[Global.vocabSize];
            for (int v = 0; v < Global.vocabSize; v++)
            {
                // Smoothed log-probability of word v under this topic.
                double p = Math.Log(wordCount[v] + eta) - logNormaliser;
                results[v] = new Result(p, vocabularyIndex[v]);
            }

            // Ordering is whatever Result's own comparison defines - presumably most
            // probable first; confirm against the Result type.
            Array.Sort(results);

            // Never read past the end: a vocabulary smaller than MaxWords used to throw
            // IndexOutOfRangeException here.
            int shown = Math.Min(MaxWords, results.Length);
            for (int i = 0; i < shown; i++)
            {
                sw.WriteLine("\t\t\t{0:0.000}\t{1}", results[i].Prob, results[i].Word);
            }

            for (int i = 0; i < children.Count; i++)
            {
                children[i].WriteTopWords(sub + "." + i, vocabularyIndex, sw);
            }
        }

        /// <summary>
        /// Counts the topics in the subtree rooted here.
        /// </summary>
        /// <returns>1 for this topic plus the counts of all of its descendants.</returns>
        public int CountTree()
        {
            int total = 1; // this node itself
            for (int i = 0; i < children.Count; i++)
            {
                total += children[i].CountTree();
            }
            return total;
        }

        /// <summary>
        /// Draws where to go next from this topic: either one of the existing children or
        /// a brand-new child.
        /// </summary>
        /// <remarks>
        /// Existing child i is weighted by its customer count, and the extra "new child"
        /// slot by <c>Global.gamma</c>; both are divided by
        /// <c>customers - 1 + Global.gamma</c>. The draw itself is delegated to
        /// <c>Sampling.Sample</c>.
        /// </remarks>
        /// <returns>
        /// The selected existing child, or a newly constructed <see cref="HldaTopic"/> that
        /// has already been attached to this topic's child list.
        /// </returns>
        public virtual HldaTopic SampleChild()
        {
            double draw = Global.random.NextDouble();
            int existing = children.Count;

            // Shared denominator of every weight, computed once instead of per child.
            double denominator = customers - 1 + Global.gamma;

            // One slot per existing child plus a final slot for "open a new child".
            // (The cumulative array the old code allocated here was never read.)
            double[] p = new double[existing + 1];
            for (int i = 0; i < existing; i++)
            {
                p[i] = children[i].customers / denominator;
            }
            p[existing] = Global.gamma / denominator;

            int chosen = Sampling.Sample(p, draw);
            if (chosen == existing)
            {
                // The constructor appends the new topic to this node's children.
                return new HldaTopic(this);
            }
            return children[chosen];
        }

        /// <summary>
        /// Detaches this topic from its parent: it is removed from the parent's child list
        /// and its <c>parent</c> reference is cleared. <c>level</c>, <c>name</c> and this
        /// topic's own children are left untouched.
        /// </summary>
        /// <remarks>
        /// Calling this on a topic that has no parent (a root, or one already removed) does
        /// nothing; previously it threw a NullReferenceException.
        /// </remarks>
        public void Remove()
        {
            if (parent == null)
            {
                return;
            }

            parent.children.Remove(this);
            parent = null;
        }

        /// <summary>
        /// Recomputes <c>ncrp</c> and <c>weights</c> for this topic, then does the same for
        /// every child, so that each node ends up holding the running totals accumulated
        /// along the path from the root down to it.
        /// </summary>
        /// <param name="f">
        /// Counts read as <c>f[level, vocabularyIndex]</c> for every vocabulary index below
        /// <c>Global.vocabSize</c>. Presumably the per-level word counts of the document
        /// being placed - confirm with the caller.
        /// </param>
        public virtual void CalculatePathToLeaves(int[,] f)
        {
            int vocabSize = Global.vocabSize;
            double eta = Global.eta;
            double etaMass = vocabSize * eta;
            double topicTotal = this.wordCount.Sum();
            int depth = this.level;

            double lnGammaCurrent = 0;
            double lnGammaWithDoc = 0;
            double docTotal = 0;
            for (int word = 0; word < vocabSize; word++)
            {
                int extra = f[depth, word];
                if (extra > 0)
                {
                    // Only words that actually occur in f contribute to the two sums.
                    lnGammaCurrent += SpecialFunctions.GammaLn(eta + wordCount[word]);
                    lnGammaWithDoc += SpecialFunctions.GammaLn(eta + wordCount[word] + extra);
                }
                docTotal += extra;
            }

            // Log-ratio of Gamma terms for this level (looks like the Dirichlet-multinomial
            // log-likelihood of f's counts under this topic - confirm).
            double levelWeight = SpecialFunctions.GammaLn(etaMass + topicTotal) - lnGammaCurrent + lnGammaWithDoc - SpecialFunctions.GammaLn(etaMass + topicTotal + docTotal);

            if (!IsRoot())
            {
                // Share of the parent's customers (plus gamma) that sit at this child.
                double share = customers / (parent.customers + Global.gamma);
                ncrp = parent.ncrp + Math.Log(share);
                weights = parent.weights + levelWeight;
            }
            else
            {
                ncrp = 0;
                weights = levelWeight;
            }

            children.ForEach(child => child.CalculatePathToLeaves(f));
        }

        /// <summary>
        /// Extends this topic's <c>weights</c> and <c>ncrp</c> with terms for every level
        /// below it (up to <c>Global.maxLevel - 1</c>), then repeats for each non-leaf
        /// child.
        /// </summary>
        /// <remarks>
        /// The per-level term uses only <c>eta</c> and the counts in
        /// <paramref name="f"/>; no topic word counts are involved. This looks like scoring
        /// a fresh, empty path branching off at this node - confirm. The values are added
        /// to whatever <c>weights</c>/<c>ncrp</c> currently hold.
        /// </remarks>
        /// <param name="f">
        /// Counts read as <c>f[level, vocabularyIndex]</c> for the levels below this topic.
        /// </param>
        public void CalculatePathToInternalNodes(int[,] f)
        {
            int vocabSize = Global.vocabSize;
            double eta = Global.eta;
            double etaMass = vocabSize * eta;
            double lnGammaEta = SpecialFunctions.GammaLn(eta);

            for (int depth = level + 1; depth < Global.maxLevel; depth++)
            {
                double distinctWords = 0;
                double lnGammaSum = 0;
                double docTotal = 0;
                for (int word = 0; word < vocabSize; word++)
                {
                    int occurrences = f[depth, word];
                    if (occurrences > 0)
                    {
                        distinctWords++;
                        lnGammaSum += SpecialFunctions.GammaLn(eta + occurrences);
                    }
                    docTotal += occurrences;
                }

                // Each distinct word contributes one lnGamma(eta) to the subtracted term.
                double emptyCountTerm = distinctWords * lnGammaEta;
                weights += SpecialFunctions.GammaLn(etaMass) - emptyCountTerm + lnGammaSum - SpecialFunctions.GammaLn(etaMass + docTotal);
            }

            ncrp += Math.Log(Global.gamma / (customers + Global.gamma));

            foreach (HldaTopic child in children)
            {
                if (child.IsLeaf())
                {
                    continue; // leaves are not visited by the recursion
                }
                child.CalculatePathToInternalNodes(f);
            }
        }

        /// <summary>
        /// Reports whether this topic has no children.
        /// </summary>
        /// <returns>True when the child list is empty.</returns>
        public bool IsLeaf()
        {
            bool hasNoChildren = children.Count < 1;
            return hasNoChildren;
        }

        /// <summary>
        /// Appends this topic and all of its descendants to <paramref name="all"/>, each
        /// parent before its children (pre-order).
        /// </summary>
        /// <param name="all">List that receives the topics; existing entries are kept.</param>
        public void GetAllTopics(List<HldaTopic> all)
        {
            all.Add(this);
            children.ForEach(child => child.GetAllTopics(all));
        }

        /// <summary>
        /// Appends every childless topic in the subtree rooted here to
        /// <paramref name="all"/>, visiting children in list order.
        /// </summary>
        /// <param name="all">List that receives the leaves; existing entries are kept.</param>
        public void GetLeaves(List<HldaTopic> all)
        {
            if (IsLeaf())
            {
                all.Add(this);
                return; // a leaf has no children left to visit
            }

            for (int i = 0; i < children.Count; i++)
            {
                children[i].GetLeaves(all);
            }
        }
    }
}
